Conversation

@Tianyue-Zhao
Contributor

This addresses the requests for CogVLM in #4387 and #4350.
CogVLM is a fairly popular model that now integrates cleanly after the recent additions to libmtmd.
I've converted a GGUF here: Link to GGUF files
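
For anyone converting their own copy, a rough sketch of the typical conversion commands (this assumes convert_hf_to_gguf.py's --mmproj flag applies to CogVLM the same way it does to the other mtmd models; exact flags and output names may differ):

# language model weights
python convert_hf_to_gguf.py ../cogvlm-chat-hf --outtype f16 --outfile ../cogvlm-chat-hf/cogvlm-13B-chat-v1.1-F16.gguf
# multimodal projector (vision encoder)
python convert_hf_to_gguf.py ../cogvlm-chat-hf --mmproj --outfile ../cogvlm-chat-hf/mmproj-cogvlm-chat-hf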

Sample command and output:

build/bin/llama-mtmd-cli -m ../cogvlm-chat-hf/cogvlm-13B-chat-v1.1-F16.gguf --mmproj ../cogvlm-chat-hf/mmproj-cogvlm-chat-hf --image ./community.png --chat-template vicuna -p "Describe the picture"

load_hparams: model size:         8448.53 MiB
load_hparams: metadata size:      0.36 MiB
alloc_compute_meta:        CPU compute buffer size =   142.02 MiB
main: loading model: ../cogvlm-chat-hf/cogvlm-13B-chat-v1.1-F16.gguf
encoding image slice...
image slice encoded in 16135 ms
decoding image batch 1/1, n_tokens_batch = 1227
image decoded (batch 1/1) in 54065 ms

1. The image showcases a futuristic urban landscape with a mix of architectural styles. The buildings are multi-storied and have a combination of traditional and modern elements. There's a prominent tree in the foreground, suggesting a blend of nature and urban development. The scene appears to be bustling with activity, with various signs and billboards, indicating commercial or residential zones.


llama_perf_context_print:        load time =  108969.65 ms
llama_perf_context_print: prompt eval time =   85229.27 ms /  1241 tokens (   68.68 ms per token,    14.56 tokens per second)
llama_perf_context_print:        eval time =   19843.15 ms /    83 runs   (  239.07 ms per token,     4.18 tokens per second)
llama_perf_context_print:       total time =  126951.23 ms /  1324 tokens
llama_perf_context_print:    graphs reused =          0

@github-actions github-actions bot added examples python python script changes labels Aug 1, 2025
@Tianyue-Zhao Tianyue-Zhao marked this pull request as ready for review August 1, 2025 02:15
@Tianyue-Zhao
Contributor Author

I think I've fixed the typecheck and format-check workflows that were failing before. Can someone approve the workflows to run again?
Also, is there a way to run these GitHub workflows locally, or without needing approval from a reviewer?
It would be good to run these CI/CD checks myself before posting the PR.

@CISC
Collaborator

CISC commented Aug 2, 2025

Also, is there a way to run these GitHub workflows locally or without needing approval from a reviewer? It would be good to run these CI/CD checks myself before posting the PR.

You can run flake8, pyright, and editorconfig checks locally (or via IDE plugins); the build tests can be run manually with ctest.
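
For example, a rough sketch of the local equivalents (the tool names are the common ones; the actual configuration comes from the repo's flake8/pyright/editorconfig files, so treat the details as assumptions):

pip install flake8 pyright editorconfig-checker
flake8 .                                    # python style lint
pyright                                     # python type check
editorconfig-checker                        # whitespace / indentation rules
cmake -B build && cmake --build build -j
ctest --test-dir build --output-on-failure  # the build tests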

Collaborator

@CISC CISC left a comment

This is not a complete review as I don't know enough about mtmd, just commenting...

@Tianyue-Zhao
Contributor Author

Also, is there a way to run these GitHub workflows locally or without needing approval from a reviewer? It would be good to run these CI/CD checks myself before posting the PR.

You can run flake8, pyright, and editorconfig checks locally (or via IDE plugins); the build tests can be run manually with ctest.

Thanks for the info! That's something I've been wondering about for a while.

Collaborator

@CISC CISC left a comment

Further refinement (merge cont+reshape).

@CISC
Collaborator

CISC commented Aug 29, 2025

Further refinement (merge cont+reshape).

After #15662 we can avoid these altogether and just create 3D views.
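
For illustration, a minimal sketch of that pattern in ggml (the tensor names and strides here are made up, not the actual CogVLM graph; it assumes the consuming ops accept non-contiguous views after #15662):

// before: materialize a contiguous copy, then reshape
struct ggml_tensor * Qcur = ggml_cont(ctx0,
        ggml_view_2d(ctx0, qkv, n_embd, n_tokens, qkv->nb[1], 0));
Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens);

// after: one strided 3D view, no extra copy
Qcur = ggml_view_3d(ctx0, qkv,
        n_embd_head, n_head, n_tokens,
        n_embd_head*ggml_element_size(qkv), qkv->nb[1], 0);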

@CISC
Collaborator

CISC commented Sep 8, 2025

Further refinement (merge cont+reshape).

After #15662 we can avoid these altogether and just create 3D views.

Merged, rebase and apply updated suggestions.

@Tianyue-Zhao
Contributor Author

Further refinement (merge cont+reshape).

After #15662 we can avoid these altogether and just create 3D views.

Merged, rebase and apply updated suggestions.

Thanks for the reminder; I've rebased it and removed the extra ggml_cont calls.

@ngxson
Collaborator

ngxson commented Oct 29, 2025

Sorry, I missed the notification to review this. Will have a look & push commits to resolve the conflicts.

Collaborator

@ngxson ngxson left a comment

I don't have enough VRAM to test the model right now, but I think the code should be good to merge (after CI passes).

Feel free to give it a try even after the PR is merged. In case there are bugs, we can make follow-up PRs to fix them.

@ngxson
Collaborator

ngxson commented Oct 30, 2025

No idea why the ASAN test failed; probably just a random runtime issue. I'm re-running the CI.

@CISC
Collaborator

CISC commented Oct 30, 2025

No idea why the ASAN test failed; probably just a random runtime issue. I'm re-running the CI.

It's ccache getting poisoned somehow; I've yet to track down the reason. When it happens, you have to find and delete all the caches on the branch + master (I've just deleted the master ones) and rerun.
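
For reference, a rough sketch of doing that with the GitHub CLI (assuming gh's cache subcommands are available and you have the required permissions on the repo):

gh cache list              # list the Actions caches for the repo
gh cache delete <cache-id> # delete a specific poisoned entry
gh cache delete --all      # or wipe everything and let CI repopulate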

@ngxson ngxson merged commit bacddc0 into ggml-org:master Oct 30, 2025
189 of 192 checks passed
@ngxson
Collaborator

ngxson commented Oct 30, 2025

Btw @Tianyue-Zhao, it seems like this implementation still uses the legacy llava preprocessing and does not support dynamic resolution. Is this expected?

SamuelOliveirads pushed a commit to SamuelOliveirads/llama.cpp that referenced this pull request Dec 29, 2025
* model : Granite docling + Idefics3 preprocessing (SmolVLM) (ggml-org#16206)

* feat: Add granite-docling conversion using trillion pretokenizer

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add granite-docling vocab pre enum

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Use granite-docling pre

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add clip_is_idefics3

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Allow multi-token boundary sequences for image templating

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Add tiling support for idefics3 in clip.cpp

This should likely be moved into llava_uhd::get_slice_instructions, but for
now this avoids disrupting the logic there.

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Partial support for full templating for idefics3 in mtmd

There are still errors encoding some of the image chunks, but the token
sequence now matches transformers _almost_ perfectly, except for the double
newline before the global image which shows up as two consecutive newline
tokens instead of a single double-newline token. I think this is happening
because the blocks are tokenized separately then concatenated.

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Fully working image preprocessing for idefics3 w/ resize and slicing

Branch: gabe-l-hart/GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* feat: Parse the preprocessor config's longest side and add it to the mmproj hparams

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Use the longest side instead of size * scale_factor

For Granite Docling, these come out to the same value, but that was just a
coincidence.

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* fix: Allow batch encoding and remove clip_is_idefics3

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* refactor: Remove unnecessary conditionals for empty token vectors

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* refactor: Use image_manipulation util

Branch: GraniteDocling

Signed-off-by: Gabe Goodhart <[email protected]>

* add test model

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Xuan Son Nguyen <[email protected]>
# Conflicts:
#	convert_hf_to_gguf.py
#	convert_hf_to_gguf_update.py
#	gguf-py/gguf/constants.py
#	gguf-py/gguf/gguf_writer.py
#	src/llama-vocab.cpp
#	src/llama-vocab.h

* mtmd : support home-cooked Mistral Small Omni (ggml-org#14928)

* model : add LightOnOCR-1B model (ggml-org#16764)

* model : add LightOnOCR-1B model

* add test
# Conflicts:
#	convert_hf_to_gguf.py
#	gguf-py/gguf/constants.py

* mtmd : fix idefics3 preprocessing (ggml-org#16806)

* mtmd : fix idefics3 preprocessing

* disable granite test

* fix test for granite

* model: Add support for CogVLM model (ggml-org#15002)

* Added GGUF mappings for CogVLM model

* Add tensor mapping for CogVLM visual encoder

* Add CogVLM to conversion script, no vision part yet

* Added CogVLM vision model to conversion script

* Add graph for CogVLM CLIP model

* Add graph for CogVLM

* Fixes for CogVLM. Now compiles.

* Model now runs

* Fixes for cogvlm graph

* Account for graph context change after rebase

* Changes for whitespace

* Changes in convert script according to comments

* Switch CogVLM LLM graph to merged QKV tensor

* Use rope_type variable instead of direct definition

* Change CogVLM CLIP encoder to use SWIGLU

* Switch CogVLM CLIP to use merged QKV

* Apply rebase edits and remove ggml_cont call that is now unnecessary

* clean up

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>
# Conflicts:
#	convert_hf_to_gguf.py
#	examples/mtmd/clip.cpp
#	gguf-py/gguf/constants.py
#	gguf-py/gguf/tensor_mapping.py
#	src/llama-arch.cpp
#	src/llama-arch.h
#	src/llama-model.cpp
#	src/llama-model.h

* mtmd: refactor preprocessing + support max/min pixels (ggml-org#16878)

* mtmd: refactor preprocessing + support max/min pixels

* fix mlp type

* implement mix/max pixels

* improve hparams

* better image preproc for qwen

* fix

* fix out of bound composite

* fix (2)

* fix token calculation

* get_merge_kernel_size()

* fix llama4 and lfm2

* gonna fix them all

* use simple resize for qwen

* qwen: increase min tokens

* no resize if dst size == src size

* restore to initial min/max tokens value for qwen
# Conflicts:
#	examples/mtmd/clip.cpp

* clip : use FA (ggml-org#16837)

* clip : use FA

* cont : add warning about unsupported ops

* implement "auto" mode for clip flash attn

* clip : print more detailed op support info during warmup

* cont : remove obsolete comment [no ci]

* improve debugging message

* trailing space

* metal : remove stray return

---------

Co-authored-by: Xuan Son Nguyen <[email protected]>

* model: add Janus Pro for image understanding (ggml-org#16906)

* Add support for Janus Pro

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Update gguf-py/gguf/tensor_mapping.py

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Address reviewer suggestions

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* Add JANUS_PRO constant

* Update clip model handling

Co-authored-by: Xuan-Son Nguyen <[email protected]>

* Update tools/mtmd/clip.cpp

Co-authored-by: Xuan-Son Nguyen <[email protected]>

* Refactor JANUS_PRO handling in clip.cpp

Co-authored-by: Xuan-Son Nguyen <[email protected]>

* Update tools/mtmd/clip.cpp

Co-authored-by: Sigbjørn Skjæret <[email protected]>

* em whitespace

---------

Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Xuan-Son Nguyen <[email protected]>
Co-authored-by: Xuan-Son Nguyen <[email protected]>
# Conflicts:
#	convert_hf_to_gguf.py
#	gguf-py/gguf/constants.py
#	gguf-py/gguf/tensor_mapping.py

* mtmd: pad mask for qwen2.5vl (ggml-org#16954)

* mtmd: pad mask for qwen2.5vl

* improve

* mtmd: add --image-min/max-tokens (ggml-org#16921)

* mtmd: improve struct initialization (ggml-org#16981)

* mtmd: allow QwenVL to process larger image by default (ggml-org#17020)

* Disable flash attention

* mtmd : fix embedding size for image input (ggml-org#17123)

* mtmd: fix patch_size initialized to random value in audio models (ggml-org#17128)

* mtmd: fix patch_size initialized to random value in audio models

* add default hparams

* add llama_model_n_embd_inp

* Fix load qwen3 vl

Change batch size

* Add description

* Fix cli build error

---------

Signed-off-by: Gabe Goodhart <[email protected]>
Co-authored-by: Gabe Goodhart <[email protected]>
Co-authored-by: Xuan Son Nguyen <[email protected]>
Co-authored-by: Tianyue-Zhao <[email protected]>
Co-authored-by: Georgi Gerganov <[email protected]>
Co-authored-by: Zhiyong Wang <[email protected]>
Co-authored-by: Sigbjørn Skjæret <[email protected]>
Co-authored-by: Xuan-Son Nguyen <[email protected]>
Co-authored-by: firecoperana <firecoperana>
Anico2 added a commit to Anico2/llama.cpp that referenced this pull request Jan 15, 2026
blime4 pushed a commit to blime4/llama.cpp that referenced this pull request Feb 5, 2026
